#include "platform.h"
#include "xbasic_types.h"
#include "xparameters.h" // Contains definitions for all peripherals
#include "xhls_multiplier.h" // Contains hls_multiplier macros and functions

#include "AEStables.h"

// we will use the Base Address of the RTL_MULTIPLIER
Xuint32 *baseaddr_rtl_multiplier =
		(Xuint32 *) XPAR_RTL_MULTIPLIER_0_S00_AXI_BASEADDR;

// AES dimensions shared by AES-128, AES-192 and AES-256 (FIPS-197).
// The literal values live in an enum because in C a const-qualified variable
// such as Nb is NOT a constant expression, so a file-scope initializer like
// "stt_lng = Nb * rows" is not portable (Clang rejects it outright).
enum { AES_NB = 4, AES_ROWS = 4 };

const unsigned short Nb = AES_NB; // state columns; AES fixes this at 4 (Rijndael allows more)
const unsigned short rows = AES_ROWS; // state rows = bytes per key-schedule word
const unsigned short stt_lng = AES_NB * AES_ROWS; // state length in bytes (16)

// Rijndael key-schedule round constants
// https://en.wikipedia.org/wiki/Rijndael_key_schedule
// rcon[i] is the byte XORed into the first byte of the transformed word for
// the i-th key-schedule step (0x01, 0x02, 0x04, ..., 0x80, 0x1b, 0x36, ...).
// rcon[0] = 0x8d is a placeholder; expansion that starts its round-constant
// index at 1 never reads it. The values repeat every 51 entries; the table is
// padded to 256 entries so that any unsigned char index passed to
// KeyExpansionCore() stays in bounds.
const unsigned char rcon[256] = { 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20,
		0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc,
		0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91,
		0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33,
		0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04,
		0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a,
		0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa,
		0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25,
		0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d,
		0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8,
		0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4,
		0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61,
		0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74,
		0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b,
		0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97,
		0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4,
		0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83,
		0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20,
		0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc,
		0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91,
		0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33,
		0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d };

/*
 * Key-schedule core transform applied in place to one 4-byte word:
 *   1. RotWord  - rotate the bytes left by one position,
 *   2. SubWord  - replace every byte with its S-box value,
 *   3. Rcon     - XOR the leading byte with the round constant rcon[i].
 * Steps 1 and 2 are fused: each slot receives the substituted value of its
 * right-hand neighbour, and the last slot receives the substituted old head.
 */
void KeyExpansionCore(unsigned char* in4, unsigned char i) {
	const unsigned char head = in4[0];

	for (int pos = 0; pos < 3; pos++) {
		in4[pos] = s_box[in4[pos + 1]];
	}
	in4[3] = s_box[head];

	// Only the first byte of the word receives the round constant.
	in4[0] ^= rcon[i];
}

/*
 * SubWord: substitute each of the 4 bytes of a key-schedule word with its
 * S-box value, in place (no rotation, no round constant).
 */
void SubWord(unsigned char* in4) {
	for (unsigned char *cell = in4; cell != in4 + 4; ++cell) {
		*cell = s_box[*cell];
	}
}

void KeyExpansion(unsigned char* inputKey, unsigned short Nk,
		unsigned char* expandedKey) {
	unsigned short Nr = (Nk > Nb) ? Nk + 6 : Nb + 6; // = 10, 12 or 14 rounds
	// Copy the inputKey at the beginning of expandedKey
	for (unsigned short i = 0; i < Nk * rows; i++) {
		expandedKey[i] = inputKey[i];
	}

	int main() {
		init_platform();

		////////////////////////////////////////////////////////////////////////////////////////
		// RTL MULTIPLIER TEST
		xil_printf("Performing a test of the RTL_MULTIPLIER... \n\r");

		// Write multiplier inputs to register 0
		*(baseaddr_rtl_multiplier + 0) = 0x00020003;
		xil_printf("Wrote to register 0: 0x%08x \n\r",
				*(baseaddr_rtl_multiplier + 0));

		// Read multiplier output from register 1
		xil_printf("Read from register 1: 0x%08x \n\r",
				*(baseaddr_rtl_multiplier + 1));

		xil_printf("End of test RTL_MULTIPLIER \n\n\r");

		////////////////////////////////////////////////////////////////////////////////////////
		// HLS MULTIPLIER TEST
		xil_printf("Performing a test of the HLS_MULTIPLIER... \n\r");

		// Vivado HLS generates
		int status;
		// Create hls_multiplier pointer
		XHls_multiplier do_hls_multiplier;
		XHls_multiplier_Config *do_hls_multiplier_cfg;
		do_hls_multiplier_cfg = XHls_multiplier_LookupConfig(
		XPAR_HLS_MULTIPLIER_0_DEVICE_ID);

		if (!do_hls_multiplier_cfg) {
			xil_printf(
					"Error loading configuration for do_hls_multiplier_cfg \n\r");
		}

		status = XHls_multiplier_CfgInitialize(&do_hls_multiplier,
				do_hls_multiplier_cfg);
		if (status != XST_SUCCESS) {
			xil_printf("Error initializing for do_hls_multiplier \n\r");
		}

		XHls_multiplier_Initialize(&do_hls_multiplier,
		XPAR_HLS_MULTIPLIER_0_DEVICE_ID); // this is optional in this case

		unsigned short int a, b;
		unsigned int p;

		a = 2;
		b = 3;
		p = 0;

		// Write multiplier inputs to register 0
		XHls_multiplier_Set_a(&do_hls_multiplier, a);
		XHls_multiplier_Set_b(&do_hls_multiplier, b);
		xil_printf("Write a: 0x%08x \n\r", a);
		xil_printf("Write b: 0x%08x \n\r", b);

		// Start hls_multiplier
		XHls_multiplier_Start(&do_hls_multiplier);
		xil_printf("Started hls_multiplier \n\r");

		// Wait until it's done (optional here)
		while (!XHls_multiplier_IsDone(&do_hls_multiplier))
			;

		// Get hls_multiplier returned value
		p = XHls_multiplier_Get_return(&do_hls_multiplier);

		xil_printf("Read p: 0x%08x \n\r", p);

		xil_printf("End of test HLS_MULTIPLIER \n\n\r");

		////////////////////////////////////////////////////////////////////////////////////////
		// HLS AES TEST

		// Variables
		unsigned short byGen = Nk * rows;
		unsigned short rconIdx = 1;
		unsigned char temp[rows];

		cleanup_platform();
		return 0;
	}
